#! /usr/bin/env python

#  zbtestcrypt - test encryption by decrypting and re-encrypting
# 
#  Adam Laurie <adam@aperturelabs.com>
#  http://www.aperturelabs.com
# 

import sys
import argparse

from scapy.all import *
from killerbee import *
from killerbee.scapy_extensions import *

from rangeparser import RangeParser

# Shared helper that expands -p/--packet_range specs (e.g. "1,5,12-19,42")
# into the set of packet numbers to process.
rangeparser = RangeParser()

# encrypt/decrypt packet with appropriate key
def crypt(packet, data= None, encrypt= False):
    global link_key
    global nwk_key

    if packet[ZigbeeSecurityHeader].key_type == 0:
        if link_key == None:
            return False, 'FAILED: No LINK_KEY provided'
        key= link_key
    elif packet[ZigbeeSecurityHeader].key_type == 1:
        if nwk_key == None:
            return False, 'FAILED: No NWK_KEY provided'
        key= nwk_key
    else:
        return False, 'Unknown KEY type'

    if encrypt:
        return True, kbencrypt(packet, data, key, verbose= args.verbose)
    else:
        return True, kbdecrypt(packet, key, verbose= args.verbose)

if __name__ == '__main__':
    # Command-line arguments
    parser = argparse.ArgumentParser(description="zbtestcrypt: \
        Decrypt, re-encrypt and optionally transmit packets for testing and tool development.")
    parser.add_argument('-b', '--begin', action='store', type=int, dest='begin', default=0,
        help='begin processing at packet #')
    parser.add_argument('-c', '--channel', action='store', type=int, default=11,
        help='tx/rx on given channel (default 11)')
    parser.add_argument('-d', '--delay', action='store', type=float, default=1.0,
        help='if tx, wait given seconds between packet injections (default 1.0)')
    parser.add_argument('-D', action='store_true', dest='showdev',
        help='list KillerBee devices')
    parser.add_argument('-e', '--end', action='store', type=int, dest='end', default=0,
        help='end processing at packet #')
    parser.add_argument('-i', '--interface', action='store', type=str, default=None,
        help='provide the USB ID or Serial Device Path to use that device')
    parser.add_argument('-k', '--network_key', action='store', type=str, default=None,
        help='provide the NWK_KEY in HEX')
    parser.add_argument('-l', '--link_key', action='store', type=str, default=None,
        help='provide the LINK_KEY in HEX')
    parser.add_argument('-p', '--packet_range', action='store', type=str, default=None,
        help='process range of packets - e.g. "1,5,12-19,42"')
    parser.add_argument('-r', '--pcapfile', action='store', default=None,
        help='pcap file to test')
    parser.add_argument('-s', '--subghz_page', action='store', type=int, default=0,
        help='tx/rx on given SubGHz page (default disabled)')
    parser.add_argument('-R', '--dsnafile', action='store', default=None,
        help='Daintree SNA file to test')
    parser.add_argument('-S', '--search_keys', action='store_true', dest='search_keys',
        help='search for AppCommandPayload packets and extract transport keys')
    parser.add_argument('-T', '--transmit_original', action='store_true', dest='transmit_original',
        help='transmit original packets for independant capture/analysis')
    parser.add_argument('-t', '--transmit_crypted', action='store_true', dest='transmit_crypted',
        help='transmit re-encrypted packets for independant capture/analysis')
    parser.add_argument('-v', '--verbose', action='store', type=int, default=0,
        help='set debug verbosity')
    args = parser.parse_args()

    if args.showdev:
        show_dev()
        exit(False)

    print

    if args.pcapfile == None and args.dsnafile == None:
        print >>sys.stderr, "ERROR: Must specify a capture file using -r (libpcap) or -R (Daintree SNA)"
        exit(True)

    if args.pcapfile != None and args.dsnafile != None:
        print >>sys.stderr, "ERROR: Must specify only one of -r (libpcap) or -R (Daintree SNA)"
        exit(True)

    if args.network_key:
        try:
            nwk_key= args.network_key.replace(':','')
            nwk_key= nwk_key.decode('hex')
            print ' NWK_KEY:', nwk_key.encode('hex').upper()
        except:
            print >>sys.stderr, "ERROR: Invalid NWK_KEY"
            exit(True)
    else:
        nwk_key= None
    if nwk_key and len(nwk_key) != 16:
        print >>sys.stderr, "ERROR: Must specify 16 byte NWK_KEY in HEX"
        exit(True)

    if args.link_key:
        try:
            link_key= args.link_key.replace(':','')
            link_key= link_key.decode('hex')
            print 'LINK_KEY:', link_key.encode('hex').upper()
        except:
            print >>sys.stderr, "ERROR: Invalid LINK_KEY"
            exit(True)
    else:
        link_key= None
    if link_key and len(link_key) != 16:
        print >>sys.stderr, "ERROR: Must specify 16 byte LINK_KEY in HEX"
        exit(True)

    if (args.begin or args.end) and args.packet_range:
        print >>sys.stderr, "ERROR: Specify begin/end or packet_range"
        exit(True)

    if args.packet_range:
        packets= rangeparser.parse(args.packet_range)
    else:
        packets= None

    if args.pcapfile is not None:
        data= kbrdpcap(args.pcapfile)
        fname= args.pcapfile

    if args.dsnafile is not None:
        data= kbrddain(args.dsnafile)
        fname= args.dsnafile

    if args.search_keys:
        transport_keys= []

    if args.transmit_original or args.transmit_crypted:
        kb = KillerBee(device=args.interface)
        if not kb.is_valid_channel(args.channel, args.subghz_page):
            print >>sys.stderr, "ERROR: Invalid channel/subghz page"
            exit(True)

    print
    print '%d packets read from %s' % (len(data), fname)

    count= 0
    testcount= 0
    failed= 0
    passed= 0
    aps_failed= 0
    aps_passed= 0
    ext_adresses= {}

    for packet in data:
        count += 1
        if args.begin and count < args.begin:
            continue
        if args.end and count > args.end:
            break
        if packets and not count in packets:
            continue
        print
        print 'Packet:', count
        #print packet.layers(False) # debug
        print packet.layers() # debug
        print '  ', packet.summary()
        if args.verbose:
            print '  ', repr(packet)

        # check for transport keys
        if args.search_keys:
            if ZigbeeAppCommandPayload in packet:
                # TODO: all transport key types - need to find sample pcaps
                # (this will probably mean repeating after decryption but need to see examples to be sure!)
                f,v = packet[ZigbeeAppCommandPayload].getfield_and_val("cmd_identifier")
                cmd_identifier = f.i2repr(None, v)
                # SKKE_1
                if packet[ZigbeeAppCommandPayload].cmd_identifier == 1:
                    f,v = packet[ZigbeeAppCommandPayload].getfield_and_val("src_addr")
                    src_addr = f.i2repr(None, v)
                    f,v = packet[ZigbeeAppCommandPayload].getfield_and_val("dest_addr")
                    dest_addr = f.i2repr(None, v)
                    transport_keys.append([count, cmd_identifier, packet[ZigbeeAppCommandPayload].key, packet[ZigbeeAppCommandPayload].key_seqnum, src_addr, dest_addr])
                else:
                    transport_keys.append([count, cmd_identifier, packet[ZigbeeAppCommandPayload].data, None, None, None])

        # ignore frames with no encrypted payload
        if not ZigbeeNWK in packet or not ZigbeeSecurityHeader in packet:
            print '    no payload - skipping!'
            continue

        if ZigbeeSecurityHeader in packet:
            # capture extended source addresses for decryption
            # format is dict of Dot15d4Data.dest_panid with dict of source -> ext_source
            if ZigbeeSecurityHeader in packet and 'ext_src' in packet[ZigbeeNWK].fields:
                if not packet[Dot15d4Data].dest_panid in ext_adresses:
                    ext_adresses[packet[Dot15d4Data].dest_panid]= {}
                # only set if ext_address is bigger (more extended) than one we've seen before
                try:
                    if packet[ZigbeeNWK].ext_src > ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr]:
                        ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr] = packet[ZigbeeNWK].ext_src
                except:
                    print 'adding', packet[ZigbeeNWK].ext_src, packet[Dot15d4Data].src_addr
                    ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr] = packet[ZigbeeNWK].ext_src
            # only decrypt if APS layer is not already in the plain
            if not ZigbeeAppDataPayload in packet and packet[ZigbeeSecurityHeader].data != '':
                print '   decrypting NWK...'
                stat, decrypted = crypt(packet)
            else:
                stat, decrypted = True, copy.copy(packet)
            if stat:
                print '     ', decrypted.summary()
                if args.verbose:
                    print '     ', repr(decrypted)
                if args.verbose > 1:
                    print ' HEX:', str(decrypted).encode('hex')

                # If we have an APS layer and it's encrypted, build new packet with ZigbeeNWK so we can decrypt it
                if ZigbeeAppDataPayload in decrypted and ZigbeeSecurityHeader in decrypted and decrypted[ZigbeeSecurityHeader].data != '':
                    tmppkt = copy.copy(packet)
                    tmppkt[ZigbeeNWK].remove_payload()
                    tmppkt /= decrypted[ZigbeeAppDataPayload]
                    # if no ext_src we must provide it
                    if not tmppkt[ZigbeeNWK].ext_src:
                        if packet[ZigbeeSecurityHeader].source:
                            tmppkt[ZigbeeNWK].ext_src= packet[ZigbeeSecurityHeader].source
                        else:
                            try:
                                tmppkt[ZigbeeNWK].ext_src = ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr]
                            except:
                                pass
                            tmppkt[ZigbeeNWK].ext_src= packet[Dot15d4Data].src_addr
                    #tmppkt[ZigbeeNWK].ext_src = ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr]
                    print 'tmppkt[ZigbeeNWK].ext_src', tmppkt[ZigbeeNWK].ext_src
                    print '      decrypting APS...'
                    if args.verbose:
                        print '        ', repr(tmppkt)
                    if args.verbose > 1:
                        print ' HEX:', str(tmppkt).encode('hex')
                    stat, app_decrypt= crypt(tmppkt)
                    if stat:
                        aps_passed += 1
                        print '        ', app_decrypt.summary()
                        if args.verbose:
                            print '        ', repr(app_decrypt)
                        if args.verbose > 1:
                            print ' HEX:', str(app_decrypt).encode('hex')
                    else:
                        aps_failed += 1
                        print '         FAILED:', app_decrypt
    
                    print '      encrypting APS...'
                    tmppkt = copy.copy(packet)
                    tmppkt[ZigbeeNWK].remove_payload()
                    tmppkt /= decrypted[ZigbeeAppDataPayload]
                    # if no ext_src we must provide it
                    if not tmppkt[ZigbeeNWK].ext_src:
                        if packet[ZigbeeSecurityHeader].source:
                            tmppkt[ZigbeeNWK].ext_src= packet[ZigbeeSecurityHeader].source
                        else:
                            try:
                                tmppkt[ZigbeeNWK].ext_src = ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr]
                            except:
                                pass
                            #tmppkt[ZigbeeNWK].ext_src= packet[Dot15d4Data].src_addr
                        print 'set tmppkt[ZigbeeNWK].ext_src!', tmppkt[ZigbeeNWK].ext_src
                    tmppkt[ZigbeeNWK].ext_src = ext_adresses[packet[Dot15d4Data].dest_panid][packet[Dot15d4Data].src_addr]
                    if args.verbose:
                        print '          ', repr(tmppkt)
                    stat, recrypt = crypt(tmppkt, data= app_decrypt, encrypt= True)
                    if stat:
                        decrypted.data= recrypt.data
                        decrypted.mic= recrypt.mic
                        if args.verbose:
                            print '          ', repr(decrypted)
                    else:
                        print '          ', recrypt

                print '   encrypting NWK...'
                stat, newpkt= crypt(packet, data= decrypted, encrypt= True)
                if stat:
                    if args.verbose:
                        print '  ', repr(newpkt)
                else:
                    print '     ', newpkt
                if newpkt == packet:
                    print '   Packet match: OK'
                    passed += 1
                else:
                    print '   Packet match: FAILED!'
                    failed += 1

            # NWK decrypt failed
            else:
                print '     FAILED:', decrypted
                newpkt= None
                failed += 1

        if args.transmit_original:
            print '      transmitting ORIGINAL on channel', args.channel
            try:
                kb.inject(str(packet), args.channel, 1, args.delay, page= args.subghz_page)
                print '        OK'
                #debug
                print packet.command()
		for x in str(packet):
			print '0x%02x' % ord(x),
		print
            except:
                print '        FAILED!'
        if args.transmit_crypted and newpkt:
            print '      transmitting RE-ENCRYPTED on channel', args.channel
            try:
                kb.inject(str(packet), args.channel, 1, args.delay, page= args.subghz_page)
                print '        OK'
            except:
                print '        FAILED!'
        testcount += 1

    print ext_adresses

    print
    print '%d of %d packets tested' % (testcount, len(data))
    print '  NWK crypto passed:', passed
    print '  NWK crypto failed:', failed
    print '    APS crypto passed:', aps_passed
    print '    APS crypto failed:', aps_failed

    if args.search_keys:
        print
        print '%d key(s) found:' % len(transport_keys)
        for key in transport_keys:
            print '  packet: %d\n    type: %s\n     key: %s (%s)\n     seq: %d\n     src: %s\n    dest: %s' % (key[0], key[1], key[2].encode('hex'), key[2][::-1].encode('hex'), key[3], key[4], key[5])
    exit(False)
